import csv


def load_data(filename):
    """Load a tab-separated file of ``text<TAB>label`` lines.

    Each line is stripped of surrounding whitespace. Blank lines are skipped.
    Every other line is split on its *last* tab: everything before it is the
    text, and the remainder is parsed as an integer label.

    :param filename: path to a UTF-8 encoded text file.
    :return: list of ``(text, label)`` tuples in file order.
    :raises ValueError: if a non-blank line contains no tab or its label is
        not an integer.
    """
    D = []
    with open(filename, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. a trailing newline at end of file)
                # instead of failing to unpack them.
                continue
            # Split on the last tab only, so text containing tabs stays intact.
            text, label = line.rsplit('\t', 1)
            D.append((text, int(label)))
    return D


def load_csv_data(param, label_id):
    """Load ``[text, label_index]`` pairs from a CSV file.

    Column 3 of each row is taken as the text. Column 2 is taken as the label
    name and mapped to an integer through ``label_id``. Empty rows (blank
    lines) are skipped. Every other row is used, including a header row if
    the file has one.

    NOTE(review): no header row is skipped here; confirm the input files
    have no header.

    :param param: path to a UTF-8 encoded CSV file.
    :param label_id: mapping from label name (column 2) to integer id.
    :return: list of ``[text, label_index]`` lists in file order.
    :raises KeyError: if a row's label is not a key of ``label_id``.
    :raises IndexError: if a non-empty row has fewer than four columns.
    """
    # newline="" lets the csv module handle line endings itself, so quoted
    # fields with embedded newlines are read back unchanged.
    with open(param, encoding="utf-8", newline="") as file:
        return [[row[3], label_id[row[2]]] for row in csv.reader(file) if row]


def save_label_id(param):
    """Build a label-name -> integer-id mapping from a CSV file.

    Ids are assigned 0, 1, 2, ... in order of each label's first appearance
    in column 2. Empty rows (blank lines) are skipped. Despite the function
    name, nothing is written to disk; the mapping is only returned.

    NOTE(review): a header row, if present, would be assigned an id too;
    confirm the input files have no header.

    :param param: path to a UTF-8 encoded CSV file.
    :return: dict mapping each distinct column-2 value to its integer id.
    :raises IndexError: if a non-empty row has fewer than three columns.
    """
    data_label_id = {}
    # newline="" is required by the csv module for correct line handling.
    with open(param, encoding="utf-8", newline="") as file:
        for row in csv.reader(file):
            if not row:
                continue
            # setdefault inserts only unseen labels. len() is evaluated before
            # insertion, so it is the next free id.
            data_label_id.setdefault(row[2], len(data_label_id))
    return data_label_id


def medical_data(train_path='medical_ten/train_data.csv',
                 valid_path='medical_ten/train_data.csv',
                 test_path='medical_ten/train_data.csv'):
    """Load the train/validation/test splits of the ``medical_ten`` dataset.

    The label-to-id mapping is built from the training file only and reused
    for all three splits. Any row in another split whose label does not
    appear in the training file makes ``load_csv_data`` raise KeyError.

    :param train_path: CSV file for the training split.
    :param valid_path: CSV file for the validation split.
    :param test_path: CSV file for the test split.
    :return: ``(train_data, valid_data, test_data)``, each a list of
        ``[text, label_index]`` pairs.
    """
    # NOTE(review): valid_path and test_path default to the *training* file,
    # which preserves the original behaviour. As a result, validation and test
    # results are measured on training data. Point them at the real held-out
    # files; confirm their names.
    label_id = save_label_id(train_path)
    # Load the datasets.
    train_data = load_csv_data(train_path, label_id)
    valid_data = load_csv_data(valid_path, label_id)
    test_data = load_csv_data(test_path, label_id)
    return train_data, valid_data, test_data
